LSTM
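一种常用的循环神经网络(RNN)模块,用于处理具有时序依赖特征的数据(如语音、文本、时间序列等)。每个时间步的公式化描述如下:
\(i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})\)
\(f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})\)
\(g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})\)
\(o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})\)
\(c_t = f_t \odot c_{t-1} + i_t \odot g_t\)
\(h_t = o_t \odot \tanh(c_t)\)
其中: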
\(x_t\) : 当前时间步输入向量
\(h_{t-1}\) : 上一时间步的隐藏状态
\(c_{t-1}\) : 上一时间步的细胞状态
\(i_t, f_t, g_t, o_t\) : 四个门(输入门、遗忘门、候选门、输出门)
\(W_*\) : 对应的权重矩阵
\(b_*\) : 偏置项
\(\sigma(\cdot)\) : Sigmoid 函数
\(\odot\) : 元素乘
- 输入:
input - 输入序列数据,形状为 \((seq\_len, batch, input\_size)\),即每个时间步的输入特征。
params - 静态参数数组,包含 LSTM 网络配置、权重、状态指针等,各元素依次为:
weight_i - 输入到各门 \((input, forget, cell, output)\) 的权重矩阵,大小为 \(4 * hidden\_size * input\_size\)。
weight_h - 上一隐藏状态到各门的权重矩阵,大小为 \(4 * hidden\_size * hidden\_size\)。
input_bias - 输入部分的偏置项,对应 4 个门的偏置,大小为 \(4 * hidden\_size\)。
state_bias - 隐藏状态部分的偏置项(也是 \(4 * hidden\_size\)),与 input_bias 一起求和形成总偏置。
hidden_state - 当前批次初始隐藏状态输入( \(h_0\) ),执行后更新为最后时刻的隐藏状态输出( \(h_t\))
cell_state - 当前批次初始细胞状态输入( \(c_0\)),执行后更新为最后时刻的细胞状态输出( \(c_t\))。
buffer - 临时工作区指针数组(中间计算缓存,如门值、激活结果、临时矩阵等,用于优化性能)。
LstmParameter - LSTM 配置参数结构体,包含输入大小、隐藏层维度、序列长度、是否双向等信息。
core_mask - 核掩码(仅适用于共享存储版本)。
LstmParameter定义:
// Configuration block describing one LSTM layer invocation.
typedef struct LstmParameter {
    // Primary parameters
    int input_size_;    // Dimension of the input vector at each time step (number of input features).
    int hidden_size_;   // Dimension of the LSTM hidden state (internal size of each gate).
    int project_size_;  // Projection-layer output dimension (for LSTMP: the hidden state is linearly compressed before output).
    int output_size_;   // Actual output dimension: hidden_size_ or project_size_, depending on whether a projection layer is used.
    int seq_len_;       // Number of time steps in the input sequence.
    int batch_;         // Batch size (number of samples processed at once).
    // Other parameters
    // The usage examples set output_step_ to batch_ * output_size_ (doubled when
    // bidirectional_ is true), i.e. the number of output elements per time step.
    // NOTE(review): an earlier comment described this field as "which time step to
    // output"; confirm the exact meaning against the kernel implementation.
    int output_step_;
    bool bidirectional_;    // Whether the LSTM is bidirectional (true: one forward and one backward layer).
    float zoneout_cell_;    // Zoneout ratio applied to the cell state (regularization).
    float zoneout_hidden_;  // Zoneout ratio applied to the hidden state (regularization).
    int input_row_align_;   // Row alignment of the input tensor.
    int input_col_align_;   // Column alignment of the input tensor.
    int state_row_align_;   // Row alignment of the state tensors (hidden/cell).
    int state_col_align_;   // Column alignment of the state tensors.
    int proj_col_align_;    // Column alignment of the projection-layer matrix.
    bool has_bias_;         // Whether bias terms are used (true: bias is applied).
} LstmParameter;
输出:
output - 计算结果地址,存放 LSTM 每个时间步输出结果的缓冲区,维度通常为 \((seq\_len, batch, output\_size)\)
- 支持平台:
FT78NE、MT7004
备注
FT78NE 支持fp32
MT7004 支持fp32、fp16
共享存储版本:
-
// Shared-storage LSTM entry point operating on float (fp32) data.
// output/input/params are described in the parameter list of this page;
// core_mask is the core mask, which only the shared-storage version takes.
void fp_lstm_s(float *output, const float *input, unsigned long long *params, int core_mask);
-
// Shared-storage LSTM entry point operating on half data; same parameter list as fp_lstm_s.
// NOTE(review): presumably the fp16 variant listed for MT7004 — confirm platform availability.
void hp_lstm_s(half *output, const half *input, unsigned long long *params, int core_mask);
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <lstm.h>
4
5int main(int argc, char* argv[]) {
6 LstmParameter *lstm_param = (LstmParameter *)0x90000000;
7 lstm_param->seq_len_ = 4;
8 lstm_param->batch_ = 1;
9 lstm_param->input_size_ = 2;
10 lstm_param->hidden_size_ = 3;
11 lstm_param->bidirectional_ = false;
12 float * input = (float *)0xA0000000;
13 float * weight_i = (float *)0xA0001000;
14 float * weight_h = (float *)0xA0003000;
15 float * bias = (float *)0xA0005000;
16 float *hidden_state = (float *)0xA0006000;
17 float *cell_state = (float *)0xA0007000;
18 float *buffer[9];
19 float * packed_input_ = (float *)0xB0000000;
20 buffer[0] = packed_input_;
21 float * gate = (float *)0xB0100000;
22 buffer[1] = gate;
23 float * packed_state = (float *)0xB0200000;
24 buffer[2] = packed_state;
25 float * state_gate = (float *)0xB0300000;
26 buffer[3] = state_gate;
27 float * cell_buffer = (float *)0xB0400000;
28 buffer[4] = cell_buffer;
29 float * hidden_buffer = (float *)0xB0500000;
30 buffer[5] = hidden_buffer;
31 float * packed_output = (float *)0xB0600000;
32 buffer[6] = packed_output;
33 float * left_matrix = (float *)0xB0700000;
34 buffer[7] = left_matrix;
35 float * packed_ptr = (float *)0xB0800000;
36 buffer[8] = packed_ptr;
37 lstm_param->hidden_size_ = 3;
38 lstm_param->output_size_ = 3;
39
40 lstm_param->output_step_ = lstm_param->bidirectional_ ? 2 * lstm_param->batch_ * lstm_param->output_size_
41 : lstm_param->batch_ * lstm_param->output_size_;
42 int weight_segment_num_ = lstm_param->bidirectional_ ? 2 * 4 : 4;
43 int row_tile_ = 12;
44 int col_tile_ = 8;
45 lstm_param->input_row_align_ = UP_ROUND(lstm_param->seq_len_ * lstm_param->batch_, row_tile_);
46 lstm_param->input_col_align_ = UP_ROUND(lstm_param->hidden_size_, col_tile_);
47 int state_row_tile_ = row_tile_;
48 int state_col_tile_ = col_tile_;
49 lstm_param->state_row_align_ = lstm_param->batch_ == 1 ? 1 : UP_ROUND(lstm_param->batch_, state_row_tile_);
50 lstm_param->state_col_align_ =
51 lstm_param->batch_ == 1 ? lstm_param->hidden_size_ : UP_ROUND(lstm_param->hidden_size_, state_col_tile_);
52 lstm_param->proj_col_align_ =
53 lstm_param->batch_ == 1 ? lstm_param->output_size_ : UP_ROUND(lstm_param->output_size_, state_col_tile_);
54 unsigned long long params[9];
55 params[0] = (unsigned long long)weight_i;
56 params[1] = (unsigned long long)weight_h;
57 params[2] = (unsigned long long)input_bias_;//ok
58 params[3] = (unsigned long long)state_bias;
59 params[4] = (unsigned long long)hidden_state;
60 params[5] = (unsigned long long)cell_state;
61 params[6] = (unsigned long long)buffer;
62 params[7] = (unsigned long long)lstm_param;
63 int core_mask = 0x11;
64
65 fp_lstm_s(output, input, params, core_mask);
66 return 0;
67}
私有存储版本:
-
// Private-storage LSTM entry point operating on float (fp32) data.
// Takes the same output/input/params arguments as the shared-storage version, without a core mask.
void fp_lstm_p(float *output, const float *input, unsigned long long *params);
-
// Private-storage LSTM entry point operating on half data; same parameter list as fp_lstm_p.
// NOTE(review): presumably the fp16 variant listed for MT7004 — confirm platform availability.
void hp_lstm_p(half *output, const half *input, unsigned long long *params);
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <lstm.h>
4int main(int argc, char* argv[]) {
5 LstmParameter *lstm_param = (LstmParameter *)0x10810000;
6 lstm_param->seq_len_ = 4;
7 lstm_param->batch_ = 1;
8 lstm_param->input_size_ = 2;
9 lstm_param->hidden_size_ = 3;
10 lstm_param->bidirectional_ = false;
11 float * input = (float *)0x10811000;
12 float * weight_i = (float *)0x10812000;
13 float * weight_h = (float *)0x10813000;
14 float * bias = (float *)0x10814000;
15 float *hidden_state = (float *)0x10814800;
16 float *cell_state = (float *)0x10815000;
17 float *buffer[9];
18 float * packed_input_ = (float *)0x10815f00;
19 buffer[0] = packed_input_;
20 float * gate = (float *)0x108160000;
21 buffer[1] = gate;
22 float * packed_state = (float *)0x10816100;
23 buffer[2] = packed_state;
24 float * state_gate = (float *)0x10816200;
25 buffer[3] = state_gate;
26 float * cell_buffer = (float *)0x10816300;
27 buffer[4] = cell_buffer;
28 float * hidden_buffer = (float *)0x10816400;
29 buffer[5] = hidden_buffer;
30 float * packed_output = (float *)0x10816500;
31 buffer[6] = packed_output;
32 float * left_matrix = (float *)0x10816600;
33 buffer[7] = left_matrix;
34 float * packed_ptr = (float *)0x10816700;
35 buffer[8] = packed_ptr;
36 lstm_param->hidden_size_ = 3;
37 lstm_param->output_size_ = 3;
38
39 lstm_param->output_step_ = lstm_param->bidirectional_ ? 2 * lstm_param->batch_ * lstm_param->output_size_
40 : lstm_param->batch_ * lstm_param->output_size_;
41 int weight_segment_num_ = lstm_param->bidirectional_ ? 2 * 4 : 4;
42 int row_tile_ = 12;
43 int col_tile_ = 8;
44 lstm_param->input_row_align_ = UP_ROUND(lstm_param->seq_len_ * lstm_param->batch_, row_tile_);
45 lstm_param->input_col_align_ = UP_ROUND(lstm_param->hidden_size_, col_tile_);
46 int state_row_tile_ = row_tile_;
47 int state_col_tile_ = col_tile_;
48 lstm_param->state_row_align_ = lstm_param->batch_ == 1 ? 1 : UP_ROUND(lstm_param->batch_, state_row_tile_);
49 lstm_param->state_col_align_ =
50 lstm_param->batch_ == 1 ? lstm_param->hidden_size_ : UP_ROUND(lstm_param->hidden_size_, state_col_tile_);
51 lstm_param->proj_col_align_ =
52 lstm_param->batch_ == 1 ? lstm_param->output_size_ : UP_ROUND(lstm_param->output_size_, state_col_tile_);
53 unsigned long long params[9];
54 params[0] = (unsigned long long)weight_i;
55 params[1] = (unsigned long long)weight_h;
56 params[2] = (unsigned long long)input_bias_;//ok
57 params[3] = (unsigned long long)state_bias;
58 params[4] = (unsigned long long)hidden_state;
59 params[5] = (unsigned long long)cell_state;
60 params[6] = (unsigned long long)buffer;
61 params[7] = (unsigned long long)lstm_param;
62 int core_mask = 0x11;
63
64 fp_lstm_s(output, input, params, core_mask);
65 return 0;
66}